/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.mix.server;

import hivemall.mix.MixMessage;
import hivemall.mix.MixMessage.MixEventName;
import hivemall.mix.store.PartialArgminKLD;
import hivemall.mix.store.PartialAverage;
import hivemall.mix.store.PartialResult;
import hivemall.mix.store.SessionObject;
import hivemall.mix.store.SessionStore;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;

import java.util.concurrent.ConcurrentMap;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

/**
 * Netty inbound handler that processes {@link MixMessage} requests on the mix server.
 *
 * <p>
 * For {@code average} and {@code argminKLD} events, the handler:
 * <ul>
 * <li>looks up the session for the message's group;</li>
 * <li>finds (or creates) the per-feature {@link PartialResult};</li>
 * <li>folds the request into that result.</li>
 * </ul>
 * When the clock difference reaches {@code syncThreshold}, it writes the mixed model value back
 * to the client. A {@code closeGroup} event removes the group's session from the store.
 *
 * <p>
 * Annotated {@link Sharable}, so one instance may be added to several channel pipelines.
 * Updates to a {@link PartialResult} are guarded by that result's own lock.
 */
@Sharable
public final class MixServerHandler extends SimpleChannelInboundHandler<MixMessage> {

    @Nonnull
    private final SessionStore sessionStore;
    /** Minimum clock difference at which a mixed value is sent back to the client. */
    private final int syncThreshold;
    /** Scaling factor passed through to every {@link PartialResult} computation. */
    private final float scale;

    /**
     * @param sessionStore store holding per-group sessions
     * @param syncThreshold minimum clock difference that triggers a response message
     * @param scale scaling factor forwarded to {@link PartialResult} add/subtract/get calls
     */
    public MixServerHandler(@Nonnull SessionStore sessionStore, @Nonnegative int syncThreshold,
            @Nonnegative float scale) {
        super();
        this.sessionStore = sessionStore;
        this.syncThreshold = syncThreshold;
        this.scale = scale;
    }

    /**
     * Dispatches an incoming message by its event type.
     *
     * @throws IllegalStateException if the event type is not handled, or if a mix request
     *         carries no group ID
     * @throws IllegalArgumentException if a mix request has a non-positive deltaUpdates
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, MixMessage msg) throws Exception {
        final MixEventName event = msg.getEvent();
        switch (event) {
            case average:
            case argminKLD: {
                final SessionObject session = getSession(msg);
                final PartialResult partial = getPartialResult(msg, session);
                mix(ctx, msg, partial, session);
                break;
            }
            case closeGroup: {
                closeGroup(msg);
                break;
            }
            default:
                throw new IllegalStateException("Unexpected event: " + event);
        }
    }

    /**
     * Removes the session of the message's group from the store. Does nothing when the message
     * carries no group ID.
     */
    private void closeGroup(@Nonnull MixMessage msg) {
        final String groupId = msg.getGroupID();
        if (groupId == null) {
            return;
        }
        sessionStore.remove(groupId);
    }

    /**
     * Returns the session for the message's group and increments its request counter.
     *
     * <p>
     * NOTE(review): assumes {@code SessionStore.get} never returns null (e.g. creates the
     * session on demand). Confirm against SessionStore.
     *
     * @throws IllegalStateException if the message carries no group ID
     */
    @Nonnull
    private SessionObject getSession(@Nonnull MixMessage msg) {
        final String groupID = msg.getGroupID();
        if (groupID == null) {
            throw new IllegalStateException("GroupID is not set in the request message");
        }
        final SessionObject session = sessionStore.get(groupID);
        session.incrRequest();
        return session;
    }

    /**
     * Returns the {@link PartialResult} for the message's feature in the given session.
     *
     * <p>
     * If none exists, creates one whose type matches the event: {@link PartialAverage} for
     * {@code average}, {@link PartialArgminKLD} for {@code argminKLD}. It is installed with
     * {@code putIfAbsent}, so concurrent callers all end up with the same instance.
     *
     * @throws IllegalStateException if the event has no associated PartialResult type
     */
    @Nonnull
    private PartialResult getPartialResult(@Nonnull MixMessage msg,
            @Nonnull SessionObject session) {
        final ConcurrentMap<Object, PartialResult> map = session.get();

        final Object feature = msg.getFeature();
        PartialResult partial = map.get(feature);
        if (partial == null) {
            final MixEventName event = msg.getEvent();
            switch (event) {
                case average:
                    partial = new PartialAverage();
                    break;
                case argminKLD:
                    partial = new PartialArgminKLD();
                    break;
                default:
                    throw new IllegalStateException("Unexpected event: " + event);
            }
            // Another thread may have installed a result first; if so, use that one.
            final PartialResult existing = map.putIfAbsent(feature, partial);
            if (existing != null) {
                partial = existing;
            }
        }
        return partial;
    }

    /**
     * Applies a request to {@code partial} while holding its lock.
     *
     * <p>
     * A cancel request subtracts the request's contribution. Any other request adds it. If the
     * clock difference observed before the add is at least {@code syncThreshold}, the current
     * mixed weight, covariance and global clock are sent back on {@code ctx}, and the session's
     * response counter is incremented.
     *
     * @throws IllegalArgumentException if the request's deltaUpdates is not positive
     */
    private void mix(final ChannelHandlerContext ctx, final MixMessage requestMsg,
            final PartialResult partial, final SessionObject session) {
        final MixEventName event = requestMsg.getEvent();
        final Object feature = requestMsg.getFeature();
        final float weight = requestMsg.getWeight();
        final float covar = requestMsg.getCovariance();
        final short localClock = requestMsg.getClock();
        final int deltaUpdates = requestMsg.getDeltaUpdates();
        final boolean cancelRequest = requestMsg.isCancelRequest();

        if (deltaUpdates <= 0) {
            throw new IllegalArgumentException("Illegal deltaUpdates received: " + deltaUpdates);
        }

        MixMessage responseMsg = null;
        // Acquire the lock before entering try so that a failed acquisition never reaches
        // unlock() in the finally block.
        partial.lock();
        try {
            if (cancelRequest) {
                partial.subtract(weight, covar, deltaUpdates, scale);
            } else {
                // The clock difference is measured before this request is added.
                final int diffClock = partial.diffClock(localClock);
                partial.add(weight, covar, deltaUpdates, scale);

                if (diffClock >= syncThreshold) {// sync model if clock DIFF is above threshold
                    final float averagedWeight = partial.getWeight(scale);
                    final float meanCovar = partial.getCovariance(scale);
                    final short globalClock = partial.getClock();
                    responseMsg = new MixMessage(event, feature, averagedWeight, meanCovar,
                        globalClock, 0 /* deltaUpdates */);
                }
            }
        } finally {
            partial.unlock();
        }

        // Write outside the lock so network I/O does not extend the critical section.
        if (responseMsg != null) {
            session.incrResponse();
            ctx.writeAndFlush(responseMsg);
        }
    }

}
